{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3d9f39e4-2a86-45bd-abac-98117aefa7ac",
   "metadata": {},
   "source": [
    "# Gemini tool-calling capabilities with Unity Catalog\n",
    "\n",
    "## Prerequisites\n",
    "\n",
    "**API Key**\n",
    "\n",
    "To run this tutorial, you will need a Gemini API key. For testing purposes, you can create a new account and use your evaluation test key (no credit card required!).\n",
    "\n",
    "Once you have acquired your key, assign it to the environment variable `GOOGLE_API_KEY`.\n",
    "\n",
    "Below, we validate that this key is set properly in your environment.\n",
    "\n",
    "**Packages**\n",
    "\n",
    "To interface with both Unity Catalog and Gemini, you will need to install the following package:\n",
    "\n",
    "```shell\n",
    "pip install unitycatalog-gemini\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "580c0b0f-3ab8-4d5a-9dea-0ee8c516c566",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "assert \"GOOGLE_API_KEY\" in os.environ, (\n",
    "    \"Please set the GOOGLE_API_KEY environment variable to your Gemini API key\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7fee661-4a32-425e-b42c-6f4efcb2eec8",
   "metadata": {},
   "source": [
    "## Configuration and Client setup\n",
    "\n",
    "To connect to your Unity Catalog server, you'll need an instance of `ApiClient` from the `unitycatalog-client` package.\n",
    "\n",
    "> Note: If you don't already have a Catalog and a Schema created, be sure to create them before running this notebook and adjust the `CATALOG` and `SCHEMA` variables below to suit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eba0f1bf-ef95-420e-91cd-392e0a4a509f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from unitycatalog.ai.core.client import UnitycatalogFunctionClient\n",
    "from unitycatalog.ai.gemini.toolkit import UCFunctionToolkit\n",
    "from unitycatalog.client import ApiClient, Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9a93a22f-4d55-428e-8576-ee11be1c009d",
   "metadata": {},
   "outputs": [],
   "source": [
    "config = Configuration()\n",
    "config.host = \"http://localhost:8080/api/2.1/unity-catalog\"\n",
    "\n",
    "# The base ApiClient is async\n",
    "api_client = ApiClient(configuration=config)\n",
    "\n",
    "client = UnitycatalogFunctionClient(api_client=api_client)\n",
    "\n",
    "CATALOG = \"AICatalog\"\n",
    "SCHEMA = \"AISchema\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d2ecec7-b8be-485b-9ead-157cb9265aad",
   "metadata": {},
   "source": [
    "## Define a function and register it to Unity Catalog\n",
    "\n",
    "In this next section, we'll define a placeholder Python function and create it within Unity Catalog so that it can be retrieved and used as a tool by a Gemini model.\n",
    "\n",
    "There are a few things to keep in mind when creating functions for use with the `create_python_function` API:\n",
    "\n",
    "- Ensure that you have properly defined types for all arguments and for the return of the function.\n",
    "- Ensure that you have a Google-style docstring defined that includes descriptions for the function, each argument, and the return of the function. This is critical, as these are used to populate the metadata associated with the function within Unity Catalog, providing contextual data for an LLM to understand when and how to call the tool associated with this function.\n",
    "- If there are packages being called that are not part of core Python, ensure that the import statements are locally scoped (defined within the function body)."
   ]
  },
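  {
   "cell_type": "markdown",
   "id": "circle-area-sketch-md",
   "metadata": {},
   "source": [
    "The guidelines above can be sketched with a small function that follows all three. The hypothetical `circle_area` below is not part of the original tutorial and is purely illustrative: it has type hints on its argument and return value, a Google-style docstring, and an import placed inside the function body. The standard-library `math` module is used only to keep the sketch runnable anywhere; the local-import guideline is aimed at packages outside core Python, where the same placement would apply."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "circle-area-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical example function (an assumption for illustration, not part of the tutorial).\n",
    "def circle_area(radius: float) -> float:\n",
    "    \"\"\"\n",
    "    Computes the area of a circle.\n",
    "\n",
    "    Args:\n",
    "        radius (float): The radius of the circle.\n",
    "\n",
    "    Returns:\n",
    "        float: The area of the circle.\n",
    "    \"\"\"\n",
    "    # The import is scoped to the function body so that it travels with the function.\n",
    "    import math\n",
    "\n",
    "    return math.pi * radius**2\n",
    "\n",
    "\n",
    "circle_area(2.0)"
   ]
  },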
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "99d7c4fa-c930-4824-aa2e-2ccf5ad73320",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fetch_weather(location: str) -> str:\n",
    "    \"\"\"\n",
    "    Fetches the current weather in celsius for a given location.\n",
    "\n",
    "    Args:\n",
    "        location (str): The location to fetch the weather for.\n",
    "\n",
    "    Returns:\n",
    "        str: The current weather in celsius for the given location.\n",
    "    \"\"\"\n",
    "\n",
    "    return \"88.2 F\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "9e2ecc06-78fa-423b-b6f4-86a649afb6cd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "FunctionInfo(name='fetch_weather', catalog_name='AICatalog', schema_name='AISchema', input_params=FunctionParameterInfos(parameters=[FunctionParameterInfo(name='location', type_text='STRING', type_json='{\"name\": \"location\", \"type\": \"string\", \"nullable\": false, \"metadata\": {\"comment\": \"The location to fetch the weather for.\"}}', type_name=<ColumnTypeName.STRING: 'STRING'>, type_precision=None, type_scale=None, type_interval_type=None, position=0, parameter_mode=None, parameter_type=None, parameter_default=None, comment='The location to fetch the weather for.')]), data_type=<ColumnTypeName.STRING: 'STRING'>, full_data_type='STRING', return_params=None, routine_body='EXTERNAL', routine_definition='return \"88.2 F\"', routine_dependencies=None, parameter_style='S', is_deterministic=True, sql_data_access='NO_SQL', is_null_call=False, security_type='DEFINER', specific_name='fetch_weather', comment='Fetches the current weather in celsius for a given location.', properties='null', full_name='AICatalog.AISchema.fetch_weather', owner=None, created_at=1734454495628, created_by=None, updated_at=1734454495628, updated_by=None, function_id='1251cfef-1a7a-46b9-b3c4-ae8ea9c1f421', external_language='PYTHON')"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "client.create_python_function(func=fetch_weather, catalog=CATALOG, schema=SCHEMA, replace=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3dbc6c5-362b-40fb-b561-f9f8c588793b",
   "metadata": {},
   "source": [
    "## Create a Toolkit instance of the function(s)\n",
    "\n",
    "Now that the function has been created within Unity Catalog, we can use the `unitycatalog-gemini` package to create a toolkit instance that Gemini will 'understand' as a valid tool in its APIs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3d948b39-fa0b-4338-ba5d-7eba40ce947b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "name='fetch_weather' catalog_name='AICatalog' schema_name='AISchema' input_params=FunctionParameterInfos(parameters=[FunctionParameterInfo(name='location', type_text='STRING', type_json='{\"name\": \"location\", \"type\": \"string\", \"nullable\": false, \"metadata\": {\"comment\": \"The location to fetch the weather for.\"}}', type_name=<ColumnTypeName.STRING: 'STRING'>, type_precision=None, type_scale=None, type_interval_type=None, position=0, parameter_mode=None, parameter_type=None, parameter_default=None, comment='The location to fetch the weather for.')]) data_type=<ColumnTypeName.STRING: 'STRING'> full_data_type='STRING' return_params=None routine_body='EXTERNAL' routine_definition='return \"88.2 F\"' routine_dependencies=None parameter_style='S' is_deterministic=True sql_data_access='NO_SQL' is_null_call=False security_type='DEFINER' specific_name='fetch_weather' comment='Fetches the current weather in celsius for a given location.' properties=None full_name='AICatalog.AISchema.fetch_weather' owner=None created_at=1734454495628 created_by=None updated_at=1734454495628 updated_by=None function_id='1251cfef-1a7a-46b9-b3c4-ae8ea9c1f421' external_language='PYTHON'\n"
     ]
    }
   ],
   "source": [
    "toolkit = UCFunctionToolkit(function_names=[f\"{CATALOG}.{SCHEMA}.fetch_weather\"], client=client)\n",
    "\n",
    "tools = toolkit.generate_callable_tool_list()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "81f780f3-02ea-454f-9578-2360c483dc34",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<google.generativeai.types.content_types.CallableFunctionDeclaration at 0x11891c160>]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tools"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5f27777-e136-4cac-bf6b-1a9bbc89ca95",
   "metadata": {},
   "source": [
    "## Call Gemini with tools\n",
    "\n",
    "With our toolkit instance ready, we can now pass our tools in a call to Gemini, giving Gemini the ability to request 'answers' from the tools that were provided with the request."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "617a5de6-a108-425c-a710-da4286efb075",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "response:\n",
       "GenerateContentResponse(\n",
       "    done=True,\n",
       "    iterator=None,\n",
       "    result=protos.GenerateContentResponse({\n",
       "      \"candidates\": [\n",
       "        {\n",
       "          \"content\": {\n",
       "            \"parts\": [\n",
       "              {\n",
       "                \"text\": \"The weather in Nome, AK is 88.2 F, and the weather in Death Valley, CA is 88.2 F.\\n\"\n",
       "              }\n",
       "            ],\n",
       "            \"role\": \"model\"\n",
       "          },\n",
       "          \"finish_reason\": \"STOP\",\n",
       "          \"safety_ratings\": [\n",
       "            {\n",
       "              \"category\": \"HARM_CATEGORY_HATE_SPEECH\",\n",
       "              \"probability\": \"NEGLIGIBLE\"\n",
       "            },\n",
       "            {\n",
       "              \"category\": \"HARM_CATEGORY_DANGEROUS_CONTENT\",\n",
       "              \"probability\": \"NEGLIGIBLE\"\n",
       "            },\n",
       "            {\n",
       "              \"category\": \"HARM_CATEGORY_HARASSMENT\",\n",
       "              \"probability\": \"NEGLIGIBLE\"\n",
       "            },\n",
       "            {\n",
       "              \"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\",\n",
       "              \"probability\": \"NEGLIGIBLE\"\n",
       "            }\n",
       "          ],\n",
       "          \"avg_logprobs\": -0.009561144536541354\n",
       "        }\n",
       "      ],\n",
       "      \"usage_metadata\": {\n",
       "        \"prompt_token_count\": 167,\n",
       "        \"candidates_token_count\": 31,\n",
       "        \"total_token_count\": 198\n",
       "      }\n",
       "    }),\n",
       ")"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Interface with Gemini via their SDK\n",
    "from google.generativeai import GenerativeModel\n",
    "\n",
    "multi = \"What's the weather in Nome, AK and in Death Valley, CA?\"\n",
    "\n",
    "\n",
    "model = GenerativeModel(model_name=\"gemini-2.0-flash-exp\", tools=tools)\n",
    "\n",
    "chat = model.start_chat(enable_automatic_function_calling=True)\n",
    "\n",
    "# Reuse the question defined above instead of repeating the string literal\n",
    "response = chat.send_message(multi)\n",
    "response"
   ]
  },
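  {
   "cell_type": "markdown",
   "id": "3d91e129-3579-4cd5-9ad7-b40cc44fc407",
   "metadata": {},
   "source": [
    "## Show details of tool calls\n",
    "\n",
    "The chat history records every turn of the exchange, including the function calls that Gemini requested and the function responses that were sent back to it."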
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "9a080f0e-74ad-419b-9c85-39a5d4a74f35",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "user -> [{'text': \"What's the weather in Nome, AK and in Death Valley, CA?\"}]\n",
      "--------------------------------------------------------------------------------\n",
      "model -> [{'function_call': {'name': 'fetch_weather', 'args': {'location': 'Nome, AK'}}}, {'function_call': {'name': 'fetch_weather', 'args': {'location': 'Death Valley, CA'}}}]\n",
      "--------------------------------------------------------------------------------\n",
      "user -> [{'function_response': {'name': 'fetch_weather', 'response': {'result': '{\"format\": \"SCALAR\", \"value\": \"88.2 F\"}'}}}, {'function_response': {'name': 'fetch_weather', 'response': {'result': '{\"format\": \"SCALAR\", \"value\": \"88.2 F\"}'}}}]\n",
      "--------------------------------------------------------------------------------\n",
      "model -> [{'text': 'The weather in Nome, AK is 88.2 F, and the weather in Death Valley, CA is 88.2 F.\\n'}]\n",
      "--------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "for content in chat.history:\n",
    "    print(content.role, \"->\", [type(part).to_dict(part) for part in content.parts])  # noqa: T201\n",
    "    print(\"-\" * 80)  # noqa: T201"
   ]
  },
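  {
   "cell_type": "markdown",
   "id": "cffb9fff",
   "metadata": {},
   "source": [
    "## Manually execute function calls\n",
    "\n",
    "Instead of relying on automatic function calling, you can drive the loop yourself. The cell below sends the conversation history to the model, checks the response for function calls with `get_function_calls`, and passes the response to `generate_tool_call_messages`, which returns the updated conversation history. The loop ends once the model's response contains no further function calls."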
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d66dc8e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from google.generativeai.types import content_types\n",
    "\n",
    "from unitycatalog.ai.gemini.utils import generate_tool_call_messages, get_function_calls\n",
    "\n",
    "history = []\n",
    "question = \"What's the weather in Nome, AK and in Death Valley, CA?\"\n",
    "\n",
    "\n",
    "content = content_types.to_content(question)\n",
    "if not content.role:\n",
    "    content.role = \"user\"\n",
    "\n",
    "history.append(content)\n",
    "\n",
    "response = model.generate_content(history)\n",
    "while function_calls := get_function_calls(response):\n",
    "    history, function_calls = generate_tool_call_messages(\n",
    "        model=model, response=response, conversation_history=history\n",
    "    )\n",
    "\n",
    "    response = model.generate_content(history)\n",
    "\n",
    "response"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
